{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 前言\n",
    "本文包含大量源码和讲解，通过段落和横线分割了各个模块，同时网站配备了侧边栏，帮助大家在各个小节中快速跳转，希望大家阅读完能对BERT有深刻的了解。同时建议通过pycharm、vscode等工具对bert源码进行单步调试，调试到对应的模块再对比看本章节的讲解。\n",
    "\n",
    "涉及到的jupyter可以在[代码库：篇章3-编写一个Transformer模型：BERT，下载](https://github.com/datawhalechina/learn-nlp-with-transformers/tree/main/docs/%E7%AF%87%E7%AB%A03-%E7%BC%96%E5%86%99%E4%B8%80%E4%B8%AATransformer%E6%A8%A1%E5%9E%8B%EF%BC%9ABERT)\n",
    "\n",
    "本篇章将基于H[HuggingFace/Transformers, 48.9k Star](https://github.com/huggingface/transformers)进行学习。本章节的全部代码在[huggingface bert，注意由于版本更新较快，可能存在差别，请以4.4.2版本为准](https://github.com/huggingface/transformers/tree/master/src/transformers/models/bert)HuggingFace 是一家总部位于纽约的聊天机器人初创服务商，很早就捕捉到 BERT 大潮流的信号并着手实现基于 pytorch 的 BERT 模型。这一项目最初名为 pytorch-pretrained-bert，在复现了原始效果的同时，提供了易用的方法以方便在这一强大模型的基础上进行各种玩耍和研究。\n",
    "\n",
    "随着使用人数的增加，这一项目也发展成为一个较大的开源社区，合并了各种预训练语言模型以及增加了 Tensorflow 的实现，并且在 2019 年下半年改名为 Transformers。截止写文章时（2021 年 3 月 30 日）这一项目已经拥有 43k+ 的star，可以说 Transformers 已经成为事实上的 NLP 基本工具。\n",
    "\n",
    "## 本小节主要内容\n",
    "![图：BERT结构](./pictures/3-6-bert.png) 图：BERT结构，来源IrEne: Interpretable Energy Prediction for Transformers\n",
    "\n",
    "本文基于 Transformers 版本 4.4.2（2021 年 3 月 19 日发布）项目中，pytorch 版的 BERT 相关代码，从代码结构、具体实现与原理，以及使用的角度进行分析。\n",
    "主要包含内容：\n",
    "1. BERT Tokenization 分词模型（BertTokenizer）\n",
    "2. BERT Model 本体模型（BertModel）\n",
    "    - BertEmbeddings\n",
    "    - BertEncoder\n",
    "        - BertLayer\n",
    "            - BertAttention\n",
    "            - BertIntermediate\n",
    "            - BertOutput\n",
    "        - BertPooler"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*** \n",
    "## 1-Tokenization分词-BertTokenizer\n",
    "和BERT 有关的 Tokenizer 主要写在[`models/bert/tokenization_bert.py`](https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py)中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import os\n",
    "import unicodedata\n",
    "from typing import List, Optional, Tuple\n",
    "\n",
    "from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace\n",
    "from transformers.utils import logging\n",
    "\n",
    "\n",
    "logger = logging.get_logger(__name__)\n",
    "\n",
    "VOCAB_FILES_NAMES = {\"vocab_file\": \"vocab.txt\"}\n",
    "\n",
    "PRETRAINED_VOCAB_FILES_MAP = {\n",
    "    \"vocab_file\": {\n",
    "        \"bert-base-uncased\": \"https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt\",\n",
    "    }\n",
    "}\n",
    "\n",
    "PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {\n",
    "    \"bert-base-uncased\": 512,\n",
    "}\n",
    "\n",
    "PRETRAINED_INIT_CONFIGURATION = {\n",
    "    \"bert-base-uncased\": {\"do_lower_case\": True},\n",
    "}\n",
    "\n",
    "\n",
    "def load_vocab(vocab_file):\n",
    "    \"\"\"Loads a vocabulary file into a dictionary.\"\"\"\n",
    "    vocab = collections.OrderedDict()\n",
    "    with open(vocab_file, \"r\", encoding=\"utf-8\") as reader:\n",
    "        tokens = reader.readlines()\n",
    "    for index, token in enumerate(tokens):\n",
    "        token = token.rstrip(\"\\n\")\n",
    "        vocab[token] = index\n",
    "    return vocab\n",
    "\n",
    "\n",
    "def whitespace_tokenize(text):\n",
    "    \"\"\"Runs basic whitespace cleaning and splitting on a piece of text.\"\"\"\n",
    "    text = text.strip()\n",
    "    if not text:\n",
    "        return []\n",
    "    tokens = text.split()\n",
    "    return tokens\n",
    "\n",
    "\n",
    "class BertTokenizer(PreTrainedTokenizer):\n",
    "\n",
    "    vocab_files_names = VOCAB_FILES_NAMES\n",
    "    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP\n",
    "    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION\n",
    "    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        vocab_file,\n",
    "        do_lower_case=True,\n",
    "        do_basic_tokenize=True,\n",
    "        never_split=None,\n",
    "        unk_token=\"[UNK]\",\n",
    "        sep_token=\"[SEP]\",\n",
    "        pad_token=\"[PAD]\",\n",
    "        cls_token=\"[CLS]\",\n",
    "        mask_token=\"[MASK]\",\n",
    "        tokenize_chinese_chars=True,\n",
    "        strip_accents=None,\n",
    "        **kwargs\n",
    "    ):\n",
    "        super().__init__(\n",
    "            do_lower_case=do_lower_case,\n",
    "            do_basic_tokenize=do_basic_tokenize,\n",
    "            never_split=never_split,\n",
    "            unk_token=unk_token,\n",
    "            sep_token=sep_token,\n",
    "            pad_token=pad_token,\n",
    "            cls_token=cls_token,\n",
    "            mask_token=mask_token,\n",
    "            tokenize_chinese_chars=tokenize_chinese_chars,\n",
    "            strip_accents=strip_accents,\n",
    "            **kwargs,\n",
    "        )\n",
    "\n",
    "        if not os.path.isfile(vocab_file):\n",
    "            raise ValueError(\n",
    "                f\"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained \"\n",
    "                \"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`\"\n",
    "            )\n",
    "        self.vocab = load_vocab(vocab_file)\n",
    "        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])\n",
    "        self.do_basic_tokenize = do_basic_tokenize\n",
    "        if do_basic_tokenize:\n",
    "            self.basic_tokenizer = BasicTokenizer(\n",
    "                do_lower_case=do_lower_case,\n",
    "                never_split=never_split,\n",
    "                tokenize_chinese_chars=tokenize_chinese_chars,\n",
    "                strip_accents=strip_accents,\n",
    "            )\n",
    "        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)\n",
    "\n",
    "    @property\n",
    "    def do_lower_case(self):\n",
    "        return self.basic_tokenizer.do_lower_case\n",
    "\n",
    "    @property\n",
    "    def vocab_size(self):\n",
    "        return len(self.vocab)\n",
    "\n",
    "    def get_vocab(self):\n",
    "        return dict(self.vocab, **self.added_tokens_encoder)\n",
    "\n",
    "    def _tokenize(self, text):\n",
    "        split_tokens = []\n",
    "        if self.do_basic_tokenize:\n",
    "            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):\n",
    "\n",
    "                # If the token is part of the never_split set\n",
    "                if token in self.basic_tokenizer.never_split:\n",
    "                    split_tokens.append(token)\n",
    "                else:\n",
    "                    split_tokens += self.wordpiece_tokenizer.tokenize(token)\n",
    "        else:\n",
    "            split_tokens = self.wordpiece_tokenizer.tokenize(text)\n",
    "        return split_tokens\n",
    "\n",
    "    def _convert_token_to_id(self, token):\n",
    "        \"\"\"Converts a token (str) in an id using the vocab.\"\"\"\n",
    "        return self.vocab.get(token, self.vocab.get(self.unk_token))\n",
    "\n",
    "    def _convert_id_to_token(self, index):\n",
    "        \"\"\"Converts an index (integer) in a token (str) using the vocab.\"\"\"\n",
    "        return self.ids_to_tokens.get(index, self.unk_token)\n",
    "\n",
    "    def convert_tokens_to_string(self, tokens):\n",
    "        \"\"\"Converts a sequence of tokens (string) in a single string.\"\"\"\n",
    "        out_string = \" \".join(tokens).replace(\" ##\", \"\").strip()\n",
    "        return out_string\n",
    "\n",
    "    def build_inputs_with_special_tokens(\n",
    "        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n",
    "    ) -> List[int]:\n",
    "        \"\"\"\n",
    "        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and\n",
    "        adding special tokens. A BERT sequence has the following format:\n",
    "        - single sequence: ``[CLS] X [SEP]``\n",
    "        - pair of sequences: ``[CLS] A [SEP] B [SEP]``\n",
    "        Args:\n",
    "            token_ids_0 (:obj:`List[int]`):\n",
    "                List of IDs to which the special tokens will be added.\n",
    "            token_ids_1 (:obj:`List[int]`, `optional`):\n",
    "                Optional second list of IDs for sequence pairs.\n",
    "        Returns:\n",
    "            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.\n",
    "        \"\"\"\n",
    "        if token_ids_1 is None:\n",
    "            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]\n",
    "        cls = [self.cls_token_id]\n",
    "        sep = [self.sep_token_id]\n",
    "        return cls + token_ids_0 + sep + token_ids_1 + sep\n",
    "\n",
    "    def get_special_tokens_mask(\n",
    "        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False\n",
    "    ) -> List[int]:\n",
    "        \"\"\"\n",
    "        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding\n",
    "        special tokens using the tokenizer ``prepare_for_model`` method.\n",
    "        Args:\n",
    "            token_ids_0 (:obj:`List[int]`):\n",
    "                List of IDs.\n",
    "            token_ids_1 (:obj:`List[int]`, `optional`):\n",
    "                Optional second list of IDs for sequence pairs.\n",
    "            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):\n",
    "                Whether or not the token list is already formatted with special tokens for the model.\n",
    "        Returns:\n",
    "            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.\n",
    "        \"\"\"\n",
    "\n",
    "        if already_has_special_tokens:\n",
    "            return super().get_special_tokens_mask(\n",
    "                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True\n",
    "            )\n",
    "\n",
    "        if token_ids_1 is not None:\n",
    "            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]\n",
    "        return [1] + ([0] * len(token_ids_0)) + [1]\n",
    "\n",
    "    def create_token_type_ids_from_sequences(\n",
    "        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None\n",
    "    ) -> List[int]:\n",
    "        \"\"\"\n",
    "        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence\n",
    "        pair mask has the following format:\n",
    "        ::\n",
    "            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1\n",
    "            | first sequence    | second sequence |\n",
    "        If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).\n",
    "        Args:\n",
    "            token_ids_0 (:obj:`List[int]`):\n",
    "                List of IDs.\n",
    "            token_ids_1 (:obj:`List[int]`, `optional`):\n",
    "                Optional second list of IDs for sequence pairs.\n",
    "        Returns:\n",
    "            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given\n",
    "            sequence(s).\n",
    "        \"\"\"\n",
    "        sep = [self.sep_token_id]\n",
    "        cls = [self.cls_token_id]\n",
    "        if token_ids_1 is None:\n",
    "            return len(cls + token_ids_0 + sep) * [0]\n",
    "        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]\n",
    "\n",
    "    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:\n",
    "        index = 0\n",
    "        if os.path.isdir(save_directory):\n",
    "            vocab_file = os.path.join(\n",
    "                save_directory, (filename_prefix + \"-\" if filename_prefix else \"\") + VOCAB_FILES_NAMES[\"vocab_file\"]\n",
    "            )\n",
    "        else:\n",
    "            vocab_file = (filename_prefix + \"-\" if filename_prefix else \"\") + save_directory\n",
    "        with open(vocab_file, \"w\", encoding=\"utf-8\") as writer:\n",
    "            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):\n",
    "                if index != token_index:\n",
    "                    logger.warning(\n",
    "                        f\"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive.\"\n",
    "                        \" Please check that the vocabulary is not corrupted!\"\n",
    "                    )\n",
    "                    index = token_index\n",
    "                writer.write(token + \"\\n\")\n",
    "                index += 1\n",
    "        return (vocab_file,)\n",
    "\n",
    "\n",
    "class BasicTokenizer(object):\n",
    "\n",
    "    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):\n",
    "        if never_split is None:\n",
    "            never_split = []\n",
    "        self.do_lower_case = do_lower_case\n",
    "        self.never_split = set(never_split)\n",
    "        self.tokenize_chinese_chars = tokenize_chinese_chars\n",
    "        self.strip_accents = strip_accents\n",
    "\n",
    "    def tokenize(self, text, never_split=None):\n",
    "        \"\"\"\n",
    "        Basic Tokenization of a piece of text. Split on \"white spaces\" only, for sub-word tokenization, see\n",
    "        WordPieceTokenizer.\n",
    "        Args:\n",
    "            **never_split**: (`optional`) list of str\n",
    "                Kept for backward compatibility purposes. Now implemented directly at the base class level (see\n",
    "                :func:`PreTrainedTokenizer.tokenize`) List of token not to split.\n",
    "        \"\"\"\n",
    "        # union() returns a new set by concatenating the two sets.\n",
    "        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split\n",
    "        text = self._clean_text(text)\n",
    "\n",
    "        # This was added on November 1st, 2018 for the multilingual and Chinese\n",
    "        # models. This is also applied to the English models now, but it doesn't\n",
    "        # matter since the English models were not trained on any Chinese data\n",
    "        # and generally don't have any Chinese data in them (there are Chinese\n",
    "        # characters in the vocabulary because Wikipedia does have some Chinese\n",
    "        # words in the English Wikipedia.).\n",
    "        if self.tokenize_chinese_chars:\n",
    "            text = self._tokenize_chinese_chars(text)\n",
    "        orig_tokens = whitespace_tokenize(text)\n",
    "        split_tokens = []\n",
    "        for token in orig_tokens:\n",
    "            if token not in never_split:\n",
    "                if self.do_lower_case:\n",
    "                    token = token.lower()\n",
    "                    if self.strip_accents is not False:\n",
    "                        token = self._run_strip_accents(token)\n",
    "                elif self.strip_accents:\n",
    "                    token = self._run_strip_accents(token)\n",
    "            split_tokens.extend(self._run_split_on_punc(token, never_split))\n",
    "\n",
    "        output_tokens = whitespace_tokenize(\" \".join(split_tokens))\n",
    "        return output_tokens\n",
    "\n",
    "    def _run_strip_accents(self, text):\n",
    "        \"\"\"Strips accents from a piece of text.\"\"\"\n",
    "        text = unicodedata.normalize(\"NFD\", text)\n",
    "        output = []\n",
    "        for char in text:\n",
    "            cat = unicodedata.category(char)\n",
    "            if cat == \"Mn\":\n",
    "                continue\n",
    "            output.append(char)\n",
    "        return \"\".join(output)\n",
    "\n",
    "    def _run_split_on_punc(self, text, never_split=None):\n",
    "        \"\"\"Splits punctuation on a piece of text.\"\"\"\n",
    "        if never_split is not None and text in never_split:\n",
    "            return [text]\n",
    "        chars = list(text)\n",
    "        i = 0\n",
    "        start_new_word = True\n",
    "        output = []\n",
    "        while i < len(chars):\n",
    "            char = chars[i]\n",
    "            if _is_punctuation(char):\n",
    "                output.append([char])\n",
    "                start_new_word = True\n",
    "            else:\n",
    "                if start_new_word:\n",
    "                    output.append([])\n",
    "                start_new_word = False\n",
    "                output[-1].append(char)\n",
    "            i += 1\n",
    "\n",
    "        return [\"\".join(x) for x in output]\n",
    "\n",
    "    def _tokenize_chinese_chars(self, text):\n",
    "        \"\"\"Adds whitespace around any CJK character.\"\"\"\n",
    "        output = []\n",
    "        for char in text:\n",
    "            cp = ord(char)\n",
    "            if self._is_chinese_char(cp):\n",
    "                output.append(\" \")\n",
    "                output.append(char)\n",
    "                output.append(\" \")\n",
    "            else:\n",
    "                output.append(char)\n",
    "        return \"\".join(output)\n",
    "\n",
    "    def _is_chinese_char(self, cp):\n",
    "        \"\"\"Checks whether CP is the codepoint of a CJK character.\"\"\"\n",
    "        # This defines a \"chinese character\" as anything in the CJK Unicode block:\n",
    "        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)\n",
    "        #\n",
    "        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,\n",
    "        # despite its name. The modern Korean Hangul alphabet is a different block,\n",
    "        # as is Japanese Hiragana and Katakana. Those alphabets are used to write\n",
    "        # space-separated words, so they are not treated specially and handled\n",
    "        # like the all of the other languages.\n",
    "        if (\n",
    "            (cp >= 0x4E00 and cp <= 0x9FFF)\n",
    "            or (cp >= 0x3400 and cp <= 0x4DBF)  #\n",
    "            or (cp >= 0x20000 and cp <= 0x2A6DF)  #\n",
    "            or (cp >= 0x2A700 and cp <= 0x2B73F)  #\n",
    "            or (cp >= 0x2B740 and cp <= 0x2B81F)  #\n",
    "            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #\n",
    "            or (cp >= 0xF900 and cp <= 0xFAFF)\n",
    "            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #\n",
    "        ):  #\n",
    "            return True\n",
    "\n",
    "        return False\n",
    "\n",
    "    def _clean_text(self, text):\n",
    "        \"\"\"Performs invalid character removal and whitespace cleanup on text.\"\"\"\n",
    "        output = []\n",
    "        for char in text:\n",
    "            cp = ord(char)\n",
    "            if cp == 0 or cp == 0xFFFD or _is_control(char):\n",
    "                continue\n",
    "            if _is_whitespace(char):\n",
    "                output.append(\" \")\n",
    "            else:\n",
    "                output.append(char)\n",
    "        return \"\".join(output)\n",
    "\n",
    "\n",
    "class WordpieceTokenizer(object):\n",
    "    \"\"\"Runs WordPiece tokenization.\"\"\"\n",
    "\n",
    "    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):\n",
    "        self.vocab = vocab\n",
    "        self.unk_token = unk_token\n",
    "        self.max_input_chars_per_word = max_input_chars_per_word\n",
    "\n",
    "    def tokenize(self, text):\n",
    "        \"\"\"\n",
    "        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform\n",
    "        tokenization using the given vocabulary.\n",
    "        For example, :obj:`input = \"unaffable\"` wil return as output :obj:`[\"un\", \"##aff\", \"##able\"]`.\n",
    "        Args:\n",
    "          text: A single token or whitespace separated tokens. This should have\n",
    "            already been passed through `BasicTokenizer`.\n",
    "        Returns:\n",
    "          A list of wordpiece tokens.\n",
    "        \"\"\"\n",
    "\n",
    "        output_tokens = []\n",
    "        for token in whitespace_tokenize(text):\n",
    "            chars = list(token)\n",
    "            if len(chars) > self.max_input_chars_per_word:\n",
    "                output_tokens.append(self.unk_token)\n",
    "                continue\n",
    "\n",
    "            is_bad = False\n",
    "            start = 0\n",
    "            sub_tokens = []\n",
    "            while start < len(chars):\n",
    "                end = len(chars)\n",
    "                cur_substr = None\n",
    "                while start < end:\n",
    "                    substr = \"\".join(chars[start:end])\n",
    "                    if start > 0:\n",
    "                        substr = \"##\" + substr\n",
    "                    if substr in self.vocab:\n",
    "                        cur_substr = substr\n",
    "                        break\n",
    "                    end -= 1\n",
    "                if cur_substr is None:\n",
    "                    is_bad = True\n",
    "                    break\n",
    "                sub_tokens.append(cur_substr)\n",
    "                start = end\n",
    "\n",
    "            if is_bad:\n",
    "                output_tokens.append(self.unk_token)\n",
    "            else:\n",
    "                output_tokens.extend(sub_tokens)\n",
    "        return output_tokens"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "class BertTokenizer(PreTrainedTokenizer):\n",
    "    \"\"\"\n",
    "    Construct a BERT tokenizer. Based on WordPiece.\n",
    "\n",
    "    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.\n",
    "    Users should refer to this superclass for more information regarding those methods.\n",
    "    ...\n",
    "    \"\"\"\n",
    "```\n",
    "\n",
    "`BertTokenizer` 是基于`BasicTokenizer`和`WordPieceTokenizer`的分词器：\n",
    "- BasicTokenizer负责处理的第一步——按标点、空格等分割句子，并处理是否统一小写，以及清理非法字符。\n",
    "    - 对于中文字符，通过预处理（加空格）来按字分割；\n",
    "    - 同时可以通过never_split指定对某些词不进行分割；\n",
    "    - 这一步是可选的（默认执行）。\n",
    "- WordPieceTokenizer在词的基础上，进一步将词分解为子词（subword）。\n",
    "    - subword 介于 char 和 word 之间，既在一定程度保留了词的含义，又能够照顾到英文中单复数、时态导致的词表爆炸和未登录词的 OOV（Out-Of-Vocabulary）问题，将词根与时态词缀等分割出来，从而减小词表，也降低了训练难度；\n",
    "    - 例如，tokenizer 这个词就可以拆解为“token”和“##izer”两部分，注意后面一个词的“##”表示接在前一个词后面。\n",
    "BertTokenizer 有以下常用方法：\n",
    "- from_pretrained：从包含词表文件（vocab.txt）的目录中初始化一个分词器；\n",
    "- tokenize：将文本（词或者句子）分解为子词列表；\n",
    "- convert_tokens_to_ids：将子词列表转化为子词对应下标的列表；\n",
    "- convert_ids_to_tokens ：与上一个相反；\n",
    "- convert_tokens_to_string：将 subword 列表按“##”拼接回词或者句子；\n",
    "- encode：对于单个句子输入，分解词并加入特殊词形成“[CLS], x, [SEP]”的结构并转换为词表对应下标的列表；对于两个句子输入（多个句子只取前两个），分解词并加入特殊词形成“[CLS], x1, [SEP], x2, [SEP]”的结构并转换为下标列表；\n",
    "- decode：可以将 encode 方法的输出变为完整句子。\n",
    "以及，类自身的方法："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: 100%|██████████| 232k/232k [00:00<00:00, 698kB/s]\n",
      "Downloading: 100%|██████████| 28.0/28.0 [00:00<00:00, 11.1kB/s]\n",
      "Downloading: 100%|██████████| 466k/466k [00:00<00:00, 863kB/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input_ids': [101, 1045, 2066, 3019, 2653, 27673, 999, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bt = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "bt('I like natural language progressing!')\n",
    "# {'input_ids': [101, 1045, 2066, 3019, 2653, 27673, 999, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*** \n",
    "## 2-Model-BertModel\n",
    "和 BERT 模型有关的代码主要写在[`/models/bert/modeling_bert.py`](https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/modeling_bert.py)中，这一份代码有一千多行，包含 BERT 模型的基本结构和基于它的微调模型等。\n",
    "\n",
    "下面从 BERT 模型本体入手分析：\n",
    "```\n",
    "class BertModel(BertPreTrainedModel):\n",
    "    \"\"\"\n",
    "\n",
    "    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of\n",
    "    cross-attention is added between the self-attention layers, following the architecture described in `Attention is\n",
    "    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,\n",
    "    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.\n",
    "\n",
    "    To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration\n",
    "    set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder`\n",
    "    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an\n",
    "    input to the forward pass.\n",
    "    \"\"\" \n",
    "```\n",
    "BertModel 主要为 transformer encoder 结构，包含三个部分：\n",
    "1. embeddings，即BertEmbeddings类的实体，根据单词符号获取对应的向量表示；\n",
    "2. encoder，即BertEncoder类的实体；\n",
    "3. pooler，即BertPooler类的实体，这一部分是可选的。\n",
    "\n",
    "**注意 BertModel 也可以配置为 Decoder，不过下文中不包含对这一部分的讨论。**\n",
    "\n",
    "下面将介绍 BertModel 的前向传播过程中各个参数的含义以及返回值：\n",
    "```\n",
    "def forward(\n",
    "        self,\n",
    "        input_ids=None,\n",
    "        attention_mask=None,\n",
    "        token_type_ids=None,\n",
    "        position_ids=None,\n",
    "        head_mask=None,\n",
    "        inputs_embeds=None,\n",
    "        encoder_hidden_states=None,\n",
    "        encoder_attention_mask=None,\n",
    "        past_key_values=None,\n",
    "        use_cache=None,\n",
    "        output_attentions=None,\n",
    "        output_hidden_states=None,\n",
    "        return_dict=None,\n",
    "    ): ...\n",
    "```\n",
    "- input_ids：经过 tokenizer 分词后的 subword 对应的下标列表；\n",
    "- attention_mask：在 self-attention 过程中，这一块 mask 用于标记 subword 所处句子和 padding 的区别，将 padding 部分填充为 0；\n",
    "- token_type_ids：标记 subword 当前所处句子（第一句/第二句/ padding）；\n",
    "- position_ids：标记当前词所在句子的位置下标；\n",
    "- head_mask：用于将某些层的某些注意力计算无效化；\n",
    "- inputs_embeds：如果提供了，那就不需要input_ids，跨过 embedding lookup 过程直接作为 Embedding 进入 Encoder 计算；\n",
    "- encoder_hidden_states：这一部分在 BertModel 配置为 decoder 时起作用，将执行 cross-attention 而不是 self-attention；\n",
    "- encoder_attention_mask：同上，在 cross-attention 中用于标记 encoder 端输入的 padding；\n",
    "- past_key_values：这个参数貌似是把预先计算好的 K-V 乘积传入，以降低 cross-attention 的开销（因为原本这部分是重复计算）；\n",
    "- use_cache：将保存上一个参数并传回，加速 decoding；\n",
    "- output_attentions：是否返回中间每层的 attention 输出；\n",
    "- output_hidden_states：是否返回中间每层的输出；\n",
    "- return_dict：是否按键值对的形式（ModelOutput 类，也可以当作 tuple 用）返回输出，默认为真。\n",
    "\n",
    "**注意，这里的 head_mask 对注意力计算的无效化，和下文提到的注意力头剪枝不同，而仅仅把某些注意力的计算结果给乘以这一系数。**\n",
    "\n",
    "输出部分如下：\n",
    "```\n",
    "# BertModel的前向传播返回部分\n",
    "        if not return_dict:\n",
    "            return (sequence_output, pooled_output) + encoder_outputs[1:]\n",
    "\n",
    "        return BaseModelOutputWithPoolingAndCrossAttentions(\n",
    "            last_hidden_state=sequence_output,\n",
    "            pooler_output=pooled_output,\n",
    "            past_key_values=encoder_outputs.past_key_values,\n",
    "            hidden_states=encoder_outputs.hidden_states,\n",
    "            attentions=encoder_outputs.attentions,\n",
    "            cross_attentions=encoder_outputs.cross_attentions,\n",
    "        )\n",
    "```\n",
    "可以看出，返回值不但包含了 encoder 和 pooler 的输出，也包含了其他指定输出的部分（hidden_states 和 attention 等，这一部分在encoder_outputs[1:]）方便取用：\n",
    "\n",
    "```\n",
    "        # BertEncoder的前向传播返回部分，即上面的encoder_outputs\n",
    "        if not return_dict:\n",
    "            return tuple(\n",
    "                v\n",
    "                for v in [\n",
    "                    hidden_states,\n",
    "                    next_decoder_cache,\n",
    "                    all_hidden_states,\n",
    "                    all_self_attentions,\n",
    "                    all_cross_attentions,\n",
    "                ]\n",
    "                if v is not None\n",
    "            )\n",
    "        return BaseModelOutputWithPastAndCrossAttentions(\n",
    "            last_hidden_state=hidden_states,\n",
    "            past_key_values=next_decoder_cache,\n",
    "            hidden_states=all_hidden_states,\n",
    "            attentions=all_self_attentions,\n",
    "            cross_attentions=all_cross_attentions,\n",
    "        )\n",
    "```\n",
    "\n",
    "此外，BertModel 还有以下的方法，方便 BERT 玩家进行各种操作：\n",
    "\n",
    "- get_input_embeddings：提取 embedding 中的 word_embeddings 即词向量部分；\n",
    "- set_input_embeddings：为 embedding 中的 word_embeddings 赋值；\n",
    "- _prune_heads：提供了将注意力头剪枝的函数，输入为{layer_num: list of heads to prune in this layer}的字典，可以将指定层的某些注意力头剪枝。\n",
    "\n",
    "** 剪枝是一个复杂的操作，需要将保留的注意力头部分的 Wq、Kq、Vq 和拼接后全连接部分的权重拷贝到一个新的较小的权重矩阵（注意先禁止 grad 再拷贝），并实时记录被剪掉的头以防下标出错。具体参考BertAttention部分的prune_heads方法.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers.models.bert.modeling_bert import *\n",
    "class BertModel(BertPreTrainedModel):\n",
    "    \"\"\"\n",
    "    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of\n",
    "    cross-attention is added between the self-attention layers, following the architecture described in `Attention is\n",
    "    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,\n",
    "    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.\n",
    "    To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration\n",
    "    set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder`\n",
    "    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an\n",
    "    input to the forward pass.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config, add_pooling_layer=True):\n",
    "        super().__init__(config)\n",
    "        self.config = config\n",
    "\n",
    "        self.embeddings = BertEmbeddings(config)\n",
    "        self.encoder = BertEncoder(config)\n",
    "\n",
    "        self.pooler = BertPooler(config) if add_pooling_layer else None\n",
    "\n",
    "        self.init_weights()\n",
    "\n",
    "    def get_input_embeddings(self):\n",
    "        return self.embeddings.word_embeddings\n",
    "\n",
    "    def set_input_embeddings(self, value):\n",
    "        self.embeddings.word_embeddings = value\n",
    "\n",
    "    def _prune_heads(self, heads_to_prune):\n",
    "        \"\"\"\n",
    "        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base\n",
    "        class PreTrainedModel\n",
    "        \"\"\"\n",
    "        for layer, heads in heads_to_prune.items():\n",
    "            self.encoder.layer[layer].attention.prune_heads(heads)\n",
    "\n",
    "    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n",
    "    @add_code_sample_docstrings(\n",
    "        tokenizer_class=_TOKENIZER_FOR_DOC,\n",
    "        checkpoint=_CHECKPOINT_FOR_DOC,\n",
    "        output_type=BaseModelOutputWithPoolingAndCrossAttentions,\n",
    "        config_class=_CONFIG_FOR_DOC,\n",
    "    )\n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids=None,\n",
    "        attention_mask=None,\n",
    "        token_type_ids=None,\n",
    "        position_ids=None,\n",
    "        head_mask=None,\n",
    "        inputs_embeds=None,\n",
    "        encoder_hidden_states=None,\n",
    "        encoder_attention_mask=None,\n",
    "        past_key_values=None,\n",
    "        use_cache=None,\n",
    "        output_attentions=None,\n",
    "        output_hidden_states=None,\n",
    "        return_dict=None,\n",
    "    ):\n",
    "        r\"\"\"\n",
    "        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):\n",
    "            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if\n",
    "            the model is configured as a decoder.\n",
    "        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n",
    "            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in\n",
    "            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:\n",
    "            - 1 for tokens that are **not masked**,\n",
    "            - 0 for tokens that are **masked**.\n",
    "        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):\n",
    "            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.\n",
    "            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`\n",
    "            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`\n",
    "            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.\n",
    "        use_cache (:obj:`bool`, `optional`):\n",
    "            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up\n",
    "            decoding (see :obj:`past_key_values`).\n",
    "        \"\"\"\n",
    "        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n",
    "        output_hidden_states = (\n",
    "            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n",
    "        )\n",
    "        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "\n",
    "        if self.config.is_decoder:\n",
    "            use_cache = use_cache if use_cache is not None else self.config.use_cache\n",
    "        else:\n",
    "            use_cache = False\n",
    "\n",
    "        if input_ids is not None and inputs_embeds is not None:\n",
    "            raise ValueError(\"You cannot specify both input_ids and inputs_embeds at the same time\")\n",
    "        elif input_ids is not None:\n",
    "            input_shape = input_ids.size()\n",
    "            batch_size, seq_length = input_shape\n",
    "        elif inputs_embeds is not None:\n",
    "            input_shape = inputs_embeds.size()[:-1]\n",
    "            batch_size, seq_length = input_shape\n",
    "        else:\n",
    "            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n",
    "\n",
    "        device = input_ids.device if input_ids is not None else inputs_embeds.device\n",
    "\n",
    "        # past_key_values_length\n",
    "        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0\n",
    "\n",
    "        if attention_mask is None:\n",
    "            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)\n",
    "\n",
    "        if token_type_ids is None:\n",
    "            if hasattr(self.embeddings, \"token_type_ids\"):\n",
    "                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]\n",
    "                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)\n",
    "                token_type_ids = buffered_token_type_ids_expanded\n",
    "            else:\n",
    "                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)\n",
    "\n",
    "        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n",
    "        # ourselves in which case we just need to make it broadcastable to all heads.\n",
    "        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)\n",
    "\n",
    "        # If a 2D or 3D attention mask is provided for the cross-attention\n",
    "        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]\n",
    "        if self.config.is_decoder and encoder_hidden_states is not None:\n",
    "            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()\n",
    "            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)\n",
    "            if encoder_attention_mask is None:\n",
    "                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)\n",
    "            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)\n",
    "        else:\n",
    "            encoder_extended_attention_mask = None\n",
    "\n",
    "        # Prepare head mask if needed\n",
    "        # 1.0 in head_mask indicate we keep the head\n",
    "        # attention_probs has shape bsz x n_heads x N x N\n",
    "        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]\n",
    "        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]\n",
    "        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)\n",
    "\n",
    "        embedding_output = self.embeddings(\n",
    "            input_ids=input_ids,\n",
    "            position_ids=position_ids,\n",
    "            token_type_ids=token_type_ids,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            past_key_values_length=past_key_values_length,\n",
    "        )\n",
    "        encoder_outputs = self.encoder(\n",
    "            embedding_output,\n",
    "            attention_mask=extended_attention_mask,\n",
    "            head_mask=head_mask,\n",
    "            encoder_hidden_states=encoder_hidden_states,\n",
    "            encoder_attention_mask=encoder_extended_attention_mask,\n",
    "            past_key_values=past_key_values,\n",
    "            use_cache=use_cache,\n",
    "            output_attentions=output_attentions,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            return_dict=return_dict,\n",
    "        )\n",
    "        sequence_output = encoder_outputs[0]\n",
    "        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None\n",
    "\n",
    "        if not return_dict:\n",
    "            return (sequence_output, pooled_output) + encoder_outputs[1:]\n",
    "\n",
    "        return BaseModelOutputWithPoolingAndCrossAttentions(\n",
    "            last_hidden_state=sequence_output,\n",
    "            pooler_output=pooled_output,\n",
    "            past_key_values=encoder_outputs.past_key_values,\n",
    "            hidden_states=encoder_outputs.hidden_states,\n",
    "            attentions=encoder_outputs.attentions,\n",
    "            cross_attentions=encoder_outputs.cross_attentions,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "***\n",
    "### 2.1-BertEmbeddings\n",
    "包含三个部分求和得到：\n",
    "![Bert-embedding](./pictures/3-0-embedding.png) 图：Bert-embedding\n",
    "\n",
    "1. word_embeddings，上文中 subword 对应的嵌入。\n",
    "2. token_type_embeddings，用于表示当前词所在的句子，辅助区别句子与 padding、句子对间的差异。\n",
    "3。 position_embeddings，句子中每个词的位置嵌入，用于区别词的顺序。和 transformer 论文中的设计不同，这一块是训练出来的，而不是通过 Sinusoidal 函数计算得到的固定嵌入。一般认为这种实现不利于拓展性（难以直接迁移到更长的句子中）。\n",
    "\n",
    "三个 embedding 不带权重相加，并通过一层 LayerNorm+dropout 后输出，其大小为(batch_size, sequence_length, hidden_size)。\n",
    "\n",
    "** [这里为什么要用 LayerNorm+Dropout 呢？为什么要用 LayerNorm 而不是 BatchNorm？可以参考一个不错的回答：transformer 为什么使用 layer normalization，而不是其他的归一化方法？](https://www.zhihu.com/question/395811291/answer/1260290120)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BertEmbeddings(nn.Module):\n",
    "    \"\"\"Construct the embeddings from word, position and token_type embeddings.\"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)\n",
    "        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)\n",
    "        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)\n",
    "\n",
    "        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load\n",
    "        # any TensorFlow checkpoint file\n",
    "        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n",
    "        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
    "        # position_ids (1, len position emb) is contiguous in memory and exported when serialized\n",
    "        self.position_embedding_type = getattr(config, \"position_embedding_type\", \"absolute\")\n",
    "        self.register_buffer(\"position_ids\", torch.arange(config.max_position_embeddings).expand((1, -1)))\n",
    "        if version.parse(torch.__version__) > version.parse(\"1.6.0\"):\n",
    "            self.register_buffer(\n",
    "                \"token_type_ids\",\n",
    "                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),\n",
    "                persistent=False,\n",
    "            )\n",
    "\n",
    "    def forward(\n",
    "        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0\n",
    "    ):\n",
    "        if input_ids is not None:\n",
    "            input_shape = input_ids.size()\n",
    "        else:\n",
    "            input_shape = inputs_embeds.size()[:-1]\n",
    "\n",
    "        seq_length = input_shape[1]\n",
    "\n",
    "        if position_ids is None:\n",
    "            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]\n",
    "\n",
    "        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs\n",
    "        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves\n",
    "        # issue #5664\n",
    "        if token_type_ids is None:\n",
    "            if hasattr(self, \"token_type_ids\"):\n",
    "                buffered_token_type_ids = self.token_type_ids[:, :seq_length]\n",
    "                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)\n",
    "                token_type_ids = buffered_token_type_ids_expanded\n",
    "            else:\n",
    "                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)\n",
    "\n",
    "        if inputs_embeds is None:\n",
    "            inputs_embeds = self.word_embeddings(input_ids)\n",
    "        token_type_embeddings = self.token_type_embeddings(token_type_ids)\n",
    "\n",
    "        embeddings = inputs_embeds + token_type_embeddings\n",
    "        if self.position_embedding_type == \"absolute\":\n",
    "            position_embeddings = self.position_embeddings(position_ids)\n",
    "            embeddings += position_embeddings\n",
    "        embeddings = self.LayerNorm(embeddings)\n",
    "        embeddings = self.dropout(embeddings)\n",
    "        return embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*** \n",
    "### 2.2-BertEncoder\n",
    "\n",
    "包含多层 BertLayer，这一块本身没有特别需要说明的地方，不过有一个细节值得参考：利用 gradient checkpointing 技术以降低训练时的显存占用。\n",
    "\n",
    "**gradient checkpointing 即梯度检查点，通过减少保存的计算图节点压缩模型占用空间，但是在计算梯度的时候需要重新计算没有存储的值，参考论文《Training Deep Nets with Sublinear Memory Cost》，过程如下示意图**\n",
    "![gradient-checkpointing](./pictures/3-1-gradient-checkpointing.gif) 图：gradient-checkpointing\n",
    "\n",
    "在 BertEncoder 中，gradient checkpoint 是通过 torch.utils.checkpoint.checkpoint 实现的，使用起来比较方便，可以参考文档：torch.utils.checkpoint - PyTorch 1.8.1 documentation，这一机制的具体实现比较复杂，在此不作展开。\n",
    "\n",
    "再往深一层走，就进入了 Encoder 的某一层："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BertEncoder(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        hidden_states,\n",
    "        attention_mask=None,\n",
    "        head_mask=None,\n",
    "        encoder_hidden_states=None,\n",
    "        encoder_attention_mask=None,\n",
    "        past_key_values=None,\n",
    "        use_cache=None,\n",
    "        output_attentions=False,\n",
    "        output_hidden_states=False,\n",
    "        return_dict=True,\n",
    "    ):\n",
    "        all_hidden_states = () if output_hidden_states else None\n",
    "        all_self_attentions = () if output_attentions else None\n",
    "        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None\n",
    "\n",
    "        next_decoder_cache = () if use_cache else None\n",
    "        for i, layer_module in enumerate(self.layer):\n",
    "            if output_hidden_states:\n",
    "                all_hidden_states = all_hidden_states + (hidden_states,)\n",
    "\n",
    "            layer_head_mask = head_mask[i] if head_mask is not None else None\n",
    "            past_key_value = past_key_values[i] if past_key_values is not None else None\n",
    "\n",
    "            if getattr(self.config, \"gradient_checkpointing\", False) and self.training:\n",
    "\n",
    "                if use_cache:\n",
    "                    logger.warning(\n",
    "                        \"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting \"\n",
    "                        \"`use_cache=False`...\"\n",
    "                    )\n",
    "                    use_cache = False\n",
    "\n",
    "                def create_custom_forward(module):\n",
    "                    def custom_forward(*inputs):\n",
    "                        return module(*inputs, past_key_value, output_attentions)\n",
    "\n",
    "                    return custom_forward\n",
    "\n",
    "                layer_outputs = torch.utils.checkpoint.checkpoint(\n",
    "                    create_custom_forward(layer_module),\n",
    "                    hidden_states,\n",
    "                    attention_mask,\n",
    "                    layer_head_mask,\n",
    "                    encoder_hidden_states,\n",
    "                    encoder_attention_mask,\n",
    "                )\n",
    "            else:\n",
    "                layer_outputs = layer_module(\n",
    "                    hidden_states,\n",
    "                    attention_mask,\n",
    "                    layer_head_mask,\n",
    "                    encoder_hidden_states,\n",
    "                    encoder_attention_mask,\n",
    "                    past_key_value,\n",
    "                    output_attentions,\n",
    "                )\n",
    "\n",
    "            hidden_states = layer_outputs[0]\n",
    "            if use_cache:\n",
    "                next_decoder_cache += (layer_outputs[-1],)\n",
    "            if output_attentions:\n",
    "                all_self_attentions = all_self_attentions + (layer_outputs[1],)\n",
    "                if self.config.add_cross_attention:\n",
    "                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)\n",
    "\n",
    "        if output_hidden_states:\n",
    "            all_hidden_states = all_hidden_states + (hidden_states,)\n",
    "\n",
    "        if not return_dict:\n",
    "            return tuple(\n",
    "                v\n",
    "                for v in [\n",
    "                    hidden_states,\n",
    "                    next_decoder_cache,\n",
    "                    all_hidden_states,\n",
    "                    all_self_attentions,\n",
    "                    all_cross_attentions,\n",
    "                ]\n",
    "                if v is not None\n",
    "            )\n",
    "        return BaseModelOutputWithPastAndCrossAttentions(\n",
    "            last_hidden_state=hidden_states,\n",
    "            past_key_values=next_decoder_cache,\n",
    "            hidden_states=all_hidden_states,\n",
    "            attentions=all_self_attentions,\n",
    "            cross_attentions=all_cross_attentions,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*** \n",
    "#### 2.2.1.1 BertAttention\n",
    "\n",
    "本以为 attention 的实现就在这里，没想到还要再下一层……其中，self 成员就是多头注意力的实现，而 output 成员实现 attention 后的全连接 +dropout+residual+LayerNorm 一系列操作。\n",
    "\n",
    "```\n",
    "class BertAttention(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.self = BertSelfAttention(config)\n",
    "        self.output = BertSelfOutput(config)\n",
    "        self.pruned_heads = set()\n",
    "```\n",
    "首先还是回到这一层。这里出现了上文提到的剪枝操作，即 prune_heads 方法：\n",
    "```\n",
    " def prune_heads(self, heads):\n",
    "        if len(heads) == 0:\n",
    "            return\n",
    "        heads, index = find_pruneable_heads_and_indices(\n",
    "            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads\n",
    "        )\n",
    "\n",
    "        # Prune linear layers\n",
    "        self.self.query = prune_linear_layer(self.self.query, index)\n",
    "        self.self.key = prune_linear_layer(self.self.key, index)\n",
    "        self.self.value = prune_linear_layer(self.self.value, index)\n",
    "        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)\n",
    "\n",
    "        # Update hyper params and store pruned heads\n",
    "        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)\n",
    "        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads\n",
    "        self.pruned_heads = self.pruned_heads.union(heads) \n",
    "```\n",
    "这里的具体实现概括如下：\n",
    "- `find_pruneable_heads_and_indices`是定位需要剪掉的 head，以及需要保留的维度下标 index；\n",
    "\n",
    "- `prune_linear_layer`则负责将 Wk/Wq/Wv 权重矩阵（连同 bias）中按照 index 保留没有被剪枝的维度后转移到新的矩阵。\n",
    "接下来就到重头戏——Self-Attention 的具体实现。"
   ]
  },
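  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The index bookkeeping behind pruning can be sketched in plain PyTorch. This is a simplified illustration, not the library's code: `keep_indices` and `prune_rows` are hypothetical helpers that only mimic what `find_pruneable_heads_and_indices` and `prune_linear_layer` are described as doing above, and the toy layer has 3 heads of size 4.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "def keep_indices(heads_to_prune, n_heads, head_size):\n",
    "    # Hypothetical helper: flat indices of the output dimensions that survive pruning\n",
    "    mask = torch.ones(n_heads, head_size)\n",
    "    for head in heads_to_prune:\n",
    "        mask[head] = 0\n",
    "    return torch.arange(n_heads * head_size)[mask.view(-1).bool()]\n",
    "\n",
    "def prune_rows(layer, index):\n",
    "    # Hypothetical helper: copy the kept output rows (and bias entries) into a smaller Linear\n",
    "    new_layer = nn.Linear(layer.in_features, len(index), bias=layer.bias is not None)\n",
    "    with torch.no_grad():  # disable grad before copying\n",
    "        new_layer.weight.copy_(layer.weight.index_select(0, index))\n",
    "        if layer.bias is not None:\n",
    "            new_layer.bias.copy_(layer.bias.index_select(0, index))\n",
    "    return new_layer\n",
    "\n",
    "query = nn.Linear(12, 12)                            # 3 heads x 4 dims per head\n",
    "index = keep_indices({1}, n_heads=3, head_size=4)    # prune the middle head\n",
    "pruned = prune_rows(query, index)\n",
    "print(index.tolist())       # [0, 1, 2, 3, 8, 9, 10, 11]\n",
    "print(pruned.weight.shape)  # torch.Size([8, 12])\n",
    "```\n",
    "\n",
    "Pruning head 1 keeps rows 0-3 and 8-11, so the projection shrinks from 12 to 8 output features while the retained rows are copied unchanged."
   ]
  },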
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BertAttention(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.self = BertSelfAttention(config)\n",
    "        self.output = BertSelfOutput(config)\n",
    "        self.pruned_heads = set()\n",
    "\n",
    "    def prune_heads(self, heads):\n",
    "        if len(heads) == 0:\n",
    "            return\n",
    "        heads, index = find_pruneable_heads_and_indices(\n",
    "            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads\n",
    "        )\n",
    "\n",
    "        # Prune linear layers\n",
    "        self.self.query = prune_linear_layer(self.self.query, index)\n",
    "        self.self.key = prune_linear_layer(self.self.key, index)\n",
    "        self.self.value = prune_linear_layer(self.self.value, index)\n",
    "        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)\n",
    "\n",
    "        # Update hyper params and store pruned heads\n",
    "        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)\n",
    "        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads\n",
    "        self.pruned_heads = self.pruned_heads.union(heads)\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        hidden_states,\n",
    "        attention_mask=None,\n",
    "        head_mask=None,\n",
    "        encoder_hidden_states=None,\n",
    "        encoder_attention_mask=None,\n",
    "        past_key_value=None,\n",
    "        output_attentions=False,\n",
    "    ):\n",
    "        self_outputs = self.self(\n",
    "            hidden_states,\n",
    "            attention_mask,\n",
    "            head_mask,\n",
    "            encoder_hidden_states,\n",
    "            encoder_attention_mask,\n",
    "            past_key_value,\n",
    "            output_attentions,\n",
    "        )\n",
    "        attention_output = self.output(self_outputs[0], hidden_states)\n",
    "        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*** \n",
    "##### 2.2.1.1.1 BertSelfAttention\n",
    "\n",
    "**预警：这一块可以说是模型的核心区域，也是唯一涉及到公式的地方，所以将贴出大量代码。**\n",
    "\n",
    "初始化部分：\n",
    "```\n",
    "class BertSelfAttention(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, \"embedding_size\"):\n",
    "            raise ValueError(\n",
    "                \"The hidden size (%d) is not a multiple of the number of attention \"\n",
    "                \"heads (%d)\" % (config.hidden_size, config.num_attention_heads)\n",
    "            )\n",
    "\n",
    "        self.num_attention_heads = config.num_attention_heads\n",
    "        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)\n",
    "        self.all_head_size = self.num_attention_heads * self.attention_head_size\n",
    "\n",
    "        self.query = nn.Linear(config.hidden_size, self.all_head_size)\n",
    "        self.key = nn.Linear(config.hidden_size, self.all_head_size)\n",
    "        self.value = nn.Linear(config.hidden_size, self.all_head_size)\n",
    "\n",
    "        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)\n",
    "        self.position_embedding_type = getattr(config, \"position_embedding_type\", \"absolute\")\n",
    "        if self.position_embedding_type == \"relative_key\" or self.position_embedding_type == \"relative_key_query\":\n",
    "            self.max_position_embeddings = config.max_position_embeddings\n",
    "            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)\n",
    "\n",
    "        self.is_decoder = config.is_decoder\n",
    "```\n",
    "\n",
    "- 除掉熟悉的 query、key、value 三个权重和一个 dropout，这里还有一个谜一样的 position_embedding_type，以及 decoder 标记；\n",
    "- 注意，hidden_size 和 all_head_size 在一开始是一样的。至于为什么要看起来多此一举地设置这一个变量——显然是因为上面那个剪枝函数，剪掉几个 attention head 以后 all_head_size 自然就小了；\n",
    "\n",
    "- hidden_size 必须是 num_attention_heads 的整数倍，以 bert-base 为例，每个 attention 包含 12 个 head，hidden_size 是 768，所以每个 head 大小即 attention_head_size=768/12=64；\n",
    "\n",
    "- position_embedding_type 是什么？继续往下看就知道了.\n",
    "\n",
    "然后是重点，也就是前向传播过程。\n",
    "\n",
    "首先回顾一下 multi-head self-attention 的基本公式：\n",
    "\n",
    "$$MHA(Q, K, V) = Concat(head_1, ..., head_h)W^O$$\n",
    "$$head_i = SDPA(QW_i^Q, KW_i^K, VW_i^V)$$\n",
    "$$SDPA(Q, K, V) = softmax(\\frac{QK^T}{\\sqrt(d_k)})V$$\n",
    "\n",
    "而这些注意力头，众所周知是并行计算的，所以上面的 query、key、value 三个权重是唯一的——这并不是所有 heads 共享了权重，而是“拼接”起来了。\n",
    "\n",
    "**[原论文中多头的理由为 Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this. 而另一个比较靠谱的分析有：为什么 Transformer 需要进行 Multi-head Attention？](https://www.zhihu.com/question/341222779/answer/814111138)**\n",
    "\n",
    "看看 forward 方法：\n",
    "```\n",
    "def transpose_for_scores(self, x):\n",
    "        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)\n",
    "        x = x.view(*new_x_shape)\n",
    "        return x.permute(0, 2, 1, 3)\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        hidden_states,\n",
    "        attention_mask=None,\n",
    "        head_mask=None,\n",
    "        encoder_hidden_states=None,\n",
    "        encoder_attention_mask=None,\n",
    "        past_key_value=None,\n",
    "        output_attentions=False,\n",
    "    ):\n",
    "        mixed_query_layer = self.query(hidden_states)\n",
    "\n",
    "        # 省略一部分cross-attention的计算\n",
    "        key_layer = self.transpose_for_scores(self.key(hidden_states))\n",
    "        value_layer = self.transpose_for_scores(self.value(hidden_states))\n",
    "        query_layer = self.transpose_for_scores(mixed_query_layer)\n",
    "\n",
    "        # Take the dot product between \"query\" and \"key\" to get the raw attention scores.\n",
    "        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))\n",
    "        # ...\n",
    "```\n",
    "这里的 `transpose_for_scores` 用来把 `hidden_size` 拆成多个头输出的形状，并且将中间两维转置以进行矩阵相乘；\n",
    "\n",
    "这里 `key_layer/value_layer/query_laye`r 的形状为：(batch_size, num_attention_heads, sequence_length, attention_head_size)；\n",
    "这里 `attention_scores` 的形状为：(batch_size, num_attention_heads, sequence_length, sequence_length)，符合多个头单独计算获得的 attention map 形状。\n",
    "\n",
    "到这里实现了 K 与 Q 相乘，获得 raw attention scores 的部分，按公式接下来应该是按 $d_k$ 进行 scaling 并做 softmax 的操作。然而先出现在眼前的是一个奇怪的positional_embedding，以及一堆爱因斯坦求和：\n",
    "\n",
    "```\n",
    " # ...\n",
    "        if self.position_embedding_type == \"relative_key\" or self.position_embedding_type == \"relative_key_query\":\n",
    "            seq_length = hidden_states.size()[1]\n",
    "            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)\n",
    "            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)\n",
    "            distance = position_ids_l - position_ids_r\n",
    "            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)\n",
    "            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility\n",
    "\n",
    "            if self.position_embedding_type == \"relative_key\":\n",
    "                relative_position_scores = torch.einsum(\"bhld,lrd->bhlr\", query_layer, positional_embedding)\n",
    "                attention_scores = attention_scores + relative_position_scores\n",
    "            elif self.position_embedding_type == \"relative_key_query\":\n",
    "                relative_position_scores_query = torch.einsum(\"bhld,lrd->bhlr\", query_layer, positional_embedding)\n",
    "                relative_position_scores_key = torch.einsum(\"bhrd,lrd->bhlr\", key_layer, positional_embedding)\n",
    "                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key\n",
    "        # ...\n",
    "```\n",
    "**[关于爱因斯坦求和约定，参考以下文档：torch.einsum - PyTorch 1.8.1 documentation](https://pytorch.org/docs/stable/generated/torch.einsum.html)**\n",
    "\n",
    "\n",
    "对于不同的positional_embedding_type，有三种操作：\n",
    "\n",
    "- absolute：默认值，这部分就不用处理；\n",
    "- relative_key：对 key_layer 作处理，将其与这里的positional_embedding和 key 矩阵相乘作为 key 相关的位置编码；\n",
    "- relative_key_query：对 key 和 value 都进行相乘以作为位置编码。\n",
    "\n",
    "回到正常 attention 的流程：\n",
    "```\n",
    "# ...\n",
    "        attention_scores = attention_scores / math.sqrt(self.attention_head_size)\n",
    "        if attention_mask is not None:\n",
    "            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)\n",
    "            attention_scores = attention_scores + attention_mask  # 这里为什么是+而不是*？\n",
    "\n",
    "        # Normalize the attention scores to probabilities.\n",
    "        attention_probs = nn.Softmax(dim=-1)(attention_scores)\n",
    "\n",
    "        # This is actually dropping out entire tokens to attend to, which might\n",
    "        # seem a bit unusual, but is taken from the original Transformer paper.\n",
    "        attention_probs = self.dropout(attention_probs)\n",
    "\n",
    "        # Mask heads if we want to\n",
    "        if head_mask is not None:\n",
    "            attention_probs = attention_probs * head_mask\n",
    "\n",
    "        context_layer = torch.matmul(attention_probs, value_layer)\n",
    "\n",
    "        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()\n",
    "        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)\n",
    "        context_layer = context_layer.view(*new_context_layer_shape)\n",
    "\n",
    "        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)\n",
    "\n",
    "        # 省略decoder返回值部分……\n",
    "        return outputs\n",
    "```\n",
    "\n",
    "重大疑问：这里的attention_scores = attention_scores + attention_mask是在做什么？难道不应该是乘 mask 吗？\n",
    "- 因为这里的 attention_mask 已经【被动过手脚】，将原本为 1 的部分变为 0，而原本为 0 的部分（即 padding）变为一个较大的负数，这样相加就得到了一个较大的负值：\n",
    "- 至于为什么要用【一个较大的负数】？因为这样一来经过 softmax 操作以后这一项就会变成接近 0 的小数。\n",
    "\n",
    "```\n",
    "(Pdb) attention_mask\n",
    "tensor([[[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]],\n",
    "        [[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]],\n",
    "        [[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]],\n",
    "        ...,\n",
    "        [[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]],\n",
    "        [[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]],\n",
    "        [[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]]],\n",
    "       device='cuda:0')\n",
    "```\n",
    "\n",
    "那么，这一步是在哪里执行的呢？\n",
    "在modeling_bert.py中没有找到答案，但是在modeling_utils.py中找到了一个特别的类：class ModuleUtilsMixin，在它的get_extended_attention_mask方法中发现了端倪：\n",
    "\n",
    "```\n",
    " def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:\n",
    "        \"\"\"\n",
    "        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.\n",
    "\n",
    "        Arguments:\n",
    "            attention_mask (:obj:`torch.Tensor`):\n",
    "                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.\n",
    "            input_shape (:obj:`Tuple[int]`):\n",
    "                The shape of the input to the model.\n",
    "            device: (:obj:`torch.device`):\n",
    "                The device of the input to the model.\n",
    "\n",
    "        Returns:\n",
    "            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.\n",
    "        \"\"\"\n",
    "        # 省略一部分……\n",
    "\n",
    "        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for\n",
    "        # masked positions, this operation will create a tensor which is 0.0 for\n",
    "        # positions we want to attend and -10000.0 for masked positions.\n",
    "        # Since we are adding it to the raw scores before the softmax, this is\n",
    "        # effectively the same as removing these entirely.\n",
    "        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility\n",
    "        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0\n",
    "        return extended_attention_mask\n",
    "```\n",
    "\n",
    "那么，这个函数是在什么时候被调用的呢？和BertModel有什么关系呢？\n",
    "OK，这里涉及到 `BertModel` 的继承细节了：`BertModel`继承自`BertPreTrainedModel`，后者继承自`PreTrainedModel`，而`PreTrainedModel`继承自[nn.Module, ModuleUtilsMixin, GenerationMixin]三个基类。——好复杂的封装！\n",
    "\n",
    "这也就是说，BertModel必然在中间的某个步骤对原始的attention_mask调用了get_extended_attention_mask，导致attention_mask从原始的[1, 0]变为[0, -1e4]的取值。\n",
    "\n",
    "最终在 BertModel 的前向传播过程中找到了这一调用（第 944 行）：\n",
    "\n",
    "```\n",
    "  # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n",
    "        # ourselves in which case we just need to make it broadcastable to all heads.\n",
    "        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)\n",
    "\n",
    "```\n",
    "问题解决了：这一方法不但实现了改变 mask 的值，还将其广播（broadcast）为可以直接与 attention map 相加的形状。\n",
    "不愧是你，HuggingFace。\n",
    "\n",
    "除此之外，值得注意的细节有：\n",
    "\n",
    "- 按照每个头的维度进行缩放，对于 bert-base 就是 64 的平方根即 8；\n",
    "- attention_probs 不但做了 softmax，还用了一次 dropout，这是担心 attention 矩阵太稠密吗…… 这里也提到很不寻常，但是原始 Transformer 论文就是这么做的；\n",
    "- head_mask 就是之前提到的对多头计算的 mask，如果不设置默认是全 1，在这里就不会起作用；\n",
    "- context_layer 即 attention 矩阵与 value 矩阵的乘积，原始的大小为：(batch_size, num_attention_heads, sequence_length, attention_head_size) ；\n",
    "- context_layer 进行转置和 view 操作以后，形状就恢复了(batch_size, sequence_length, hidden_size)。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BertSelfAttention(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, \"embedding_size\"):\n",
    "            raise ValueError(\n",
    "                f\"The hidden size ({config.hidden_size}) is not a multiple of the number of attention \"\n",
    "                f\"heads ({config.num_attention_heads})\"\n",
    "            )\n",
    "\n",
    "        self.num_attention_heads = config.num_attention_heads\n",
    "        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)\n",
    "        self.all_head_size = self.num_attention_heads * self.attention_head_size\n",
    "\n",
    "        self.query = nn.Linear(config.hidden_size, self.all_head_size)\n",
    "        self.key = nn.Linear(config.hidden_size, self.all_head_size)\n",
    "        self.value = nn.Linear(config.hidden_size, self.all_head_size)\n",
    "\n",
    "        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)\n",
    "        self.position_embedding_type = getattr(config, \"position_embedding_type\", \"absolute\")\n",
    "        if self.position_embedding_type == \"relative_key\" or self.position_embedding_type == \"relative_key_query\":\n",
    "            self.max_position_embeddings = config.max_position_embeddings\n",
    "            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)\n",
    "\n",
    "        self.is_decoder = config.is_decoder\n",
    "\n",
    "    def transpose_for_scores(self, x):\n",
    "        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)\n",
    "        x = x.view(*new_x_shape)\n",
    "        return x.permute(0, 2, 1, 3)\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        hidden_states,\n",
    "        attention_mask=None,\n",
    "        head_mask=None,\n",
    "        encoder_hidden_states=None,\n",
    "        encoder_attention_mask=None,\n",
    "        past_key_value=None,\n",
    "        output_attentions=False,\n",
    "    ):\n",
    "        mixed_query_layer = self.query(hidden_states)\n",
    "\n",
    "        # If this is instantiated as a cross-attention module, the keys\n",
    "        # and values come from an encoder; the attention mask needs to be\n",
    "        # such that the encoder's padding tokens are not attended to.\n",
    "        is_cross_attention = encoder_hidden_states is not None\n",
    "\n",
    "        if is_cross_attention and past_key_value is not None:\n",
    "            # reuse k,v, cross_attentions\n",
    "            key_layer = past_key_value[0]\n",
    "            value_layer = past_key_value[1]\n",
    "            attention_mask = encoder_attention_mask\n",
    "        elif is_cross_attention:\n",
    "            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))\n",
    "            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))\n",
    "            attention_mask = encoder_attention_mask\n",
    "        elif past_key_value is not None:\n",
    "            key_layer = self.transpose_for_scores(self.key(hidden_states))\n",
    "            value_layer = self.transpose_for_scores(self.value(hidden_states))\n",
    "            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)\n",
    "            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)\n",
    "        else:\n",
    "            key_layer = self.transpose_for_scores(self.key(hidden_states))\n",
    "            value_layer = self.transpose_for_scores(self.value(hidden_states))\n",
    "\n",
    "        query_layer = self.transpose_for_scores(mixed_query_layer)\n",
    "\n",
    "        if self.is_decoder:\n",
    "            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.\n",
    "            # Further calls to cross_attention layer can then reuse all cross-attention\n",
    "            # key/value_states (first \"if\" case)\n",
    "            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of\n",
    "            # all previous decoder key/value_states. Further calls to uni-directional self-attention\n",
    "            # can concat previous decoder key/value_states to current projected key/value_states (third \"elif\" case)\n",
    "            # if encoder bi-directional self-attention `past_key_value` is always `None`\n",
    "            past_key_value = (key_layer, value_layer)\n",
    "\n",
    "        # Take the dot product between \"query\" and \"key\" to get the raw attention scores.\n",
    "        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))\n",
    "\n",
    "        if self.position_embedding_type == \"relative_key\" or self.position_embedding_type == \"relative_key_query\":\n",
    "            seq_length = hidden_states.size()[1]\n",
    "            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)\n",
    "            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)\n",
    "            distance = position_ids_l - position_ids_r\n",
    "            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)\n",
    "            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility\n",
    "\n",
    "            if self.position_embedding_type == \"relative_key\":\n",
    "                relative_position_scores = torch.einsum(\"bhld,lrd->bhlr\", query_layer, positional_embedding)\n",
    "                attention_scores = attention_scores + relative_position_scores\n",
    "            elif self.position_embedding_type == \"relative_key_query\":\n",
    "                relative_position_scores_query = torch.einsum(\"bhld,lrd->bhlr\", query_layer, positional_embedding)\n",
    "                relative_position_scores_key = torch.einsum(\"bhrd,lrd->bhlr\", key_layer, positional_embedding)\n",
    "                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key\n",
    "\n",
    "        attention_scores = attention_scores / math.sqrt(self.attention_head_size)\n",
    "        if attention_mask is not None:\n",
    "            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)\n",
    "            attention_scores = attention_scores + attention_mask\n",
    "\n",
    "        # Normalize the attention scores to probabilities.\n",
    "        attention_probs = nn.Softmax(dim=-1)(attention_scores)\n",
    "\n",
    "        # This is actually dropping out entire tokens to attend to, which might\n",
    "        # seem a bit unusual, but is taken from the original Transformer paper.\n",
    "        attention_probs = self.dropout(attention_probs)\n",
    "\n",
    "        # Mask heads if we want to\n",
    "        if head_mask is not None:\n",
    "            attention_probs = attention_probs * head_mask\n",
    "\n",
    "        context_layer = torch.matmul(attention_probs, value_layer)\n",
    "\n",
    "        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()\n",
    "        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)\n",
    "        context_layer = context_layer.view(*new_context_layer_shape)\n",
    "\n",
    "        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)\n",
    "\n",
    "        if self.is_decoder:\n",
    "            outputs = outputs + (past_key_value,)\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*** \n",
    "##### 2.2.1.1.2 BertSelfOutput\n",
    "```\n",
    "class BertSelfOutput(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n",
    "        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n",
    "        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
    "\n",
    "    def forward(self, hidden_states, input_tensor):\n",
    "        hidden_states = self.dense(hidden_states)\n",
    "        hidden_states = self.dropout(hidden_states)\n",
    "        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n",
    "        return hidden_states\n",
    "```\n",
    "\n",
    "**这里又出现了 LayerNorm 和 Dropout 的组合，只不过这里是先 Dropout，进行残差连接后再进行 LayerNorm。至于为什么要做残差连接，最直接的目的就是降低网络层数过深带来的训练难度，对原始输入更加敏感～**"
   ]
  },
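  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The order of operations in BertSelfOutput (residual connection first, LayerNorm afterwards) can be illustrated with a minimal plain-Python sketch. It is a simplification under stated assumptions: the helper names `layer_norm` and `self_output` are hypothetical, the learned scale and shift of LayerNorm are left out, dropout is treated as the identity (as at inference time), and the input values are made up:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def layer_norm(x, eps=1e-12):\n",
    "    # Normalize one vector to zero mean and unit variance (no learned scale/shift).\n",
    "    mean = sum(x) / len(x)\n",
    "    var = sum((v - mean) ** 2 for v in x) / len(x)\n",
    "    return [(v - mean) / math.sqrt(var + eps) for v in x]\n",
    "\n",
    "def self_output(sublayer_out, input_tensor):\n",
    "    # Residual connection first, then LayerNorm; dropout is the identity at inference.\n",
    "    return layer_norm([h + r for h, r in zip(sublayer_out, input_tensor)])\n",
    "\n",
    "input_tensor = [1.0, 2.0, 3.0, 4.0]    # made-up layer input\n",
    "sublayer_out = [0.5, -0.5, 0.5, -0.5]  # made-up output of the dense layer\n",
    "y = self_output(sublayer_out, input_tensor)\n",
    "print(y)\n",
    "```"
   ]
  },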
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class BertSelfOutput(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n",
    "        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n",
    "        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
    "\n",
    "    def forward(self, hidden_states, input_tensor):\n",
    "        hidden_states = self.dense(hidden_states)\n",
    "        hidden_states = self.dropout(hidden_states)\n",
    "        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n",
    "        return hidden_states"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*** \n",
    "#### 2.2.1.2 BertIntermediate\n",
    "\n",
    "看完了 BertAttention，在 Attention 后面还有一个全连接+激活的操作：\n",
    "```\n",
    "class BertIntermediate(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)\n",
    "        if isinstance(config.hidden_act, str):\n",
    "            self.intermediate_act_fn = ACT2FN[config.hidden_act]\n",
    "        else:\n",
    "            self.intermediate_act_fn = config.hidden_act\n",
    "\n",
    "    def forward(self, hidden_states):\n",
    "        hidden_states = self.dense(hidden_states)\n",
    "        hidden_states = self.intermediate_act_fn(hidden_states)\n",
    "        return hidden_states\n",
    "```\n",
    "\n",
    "- 这里的全连接做了一个扩展，以 bert-base 为例，扩展维度为 3072，是原始维度 768 的 4 倍之多；\n",
    "- 这里的激活函数默认实现为 gelu（Gaussian Error Linerar Units(GELUS）当然，它是无法直接计算的，可以用一个包含tanh的表达式进行近似（略)。"
   ]
  },
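  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the GELU activation used in BertIntermediate and its tanh-based approximation can be compared with a short plain-Python sketch. The function names `gelu_exact` and `gelu_tanh` are chosen here for illustration and are not the names used in the library:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def gelu_exact(x):\n",
    "    # GELU(x) = x * Phi(x), where Phi is the standard normal CDF, written via erf.\n",
    "    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))\n",
    "\n",
    "def gelu_tanh(x):\n",
    "    # Tanh-based approximation of GELU.\n",
    "    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)\n",
    "    return 0.5 * x * (1.0 + math.tanh(inner))\n",
    "\n",
    "# Print both versions side by side for a few sample inputs.\n",
    "for x in [-2.0, -1.0, 0.0, 1.0, 2.0]:\n",
    "    print(x, gelu_exact(x), gelu_tanh(x))\n",
    "```\n",
    "\n",
    "On these sample points the two versions agree closely, which is why the approximation is a common substitute."
   ]
  },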
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BertIntermediate(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)\n",
    "        if isinstance(config.hidden_act, str):\n",
    "            self.intermediate_act_fn = ACT2FN[config.hidden_act]\n",
    "        else:\n",
    "            self.intermediate_act_fn = config.hidden_act\n",
    "\n",
    "    def forward(self, hidden_states):\n",
    "        hidden_states = self.dense(hidden_states)\n",
    "        hidden_states = self.intermediate_act_fn(hidden_states)\n",
    "        return hidden_states"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*** \n",
    "#### 2.2.1.3 BertOutput\n",
    "\n",
    "在这里又是一个全连接 +dropout+LayerNorm，还有一个残差连接 residual connect：\n",
    "```\n",
    "class BertOutput(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)\n",
    "        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n",
    "        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
    "\n",
    "    def forward(self, hidden_states, input_tensor):\n",
    "        hidden_states = self.dense(hidden_states)\n",
    "        hidden_states = self.dropout(hidden_states)\n",
    "        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n",
    "        return hidden_states\n",
    "```\n",
    "\n",
    "这里的操作和 BertSelfOutput 不能说没有关系，只能说一模一样…… 非常容易混淆的两个组件。\n",
    "以下内容还包含基于 BERT 的应用模型，以及 BERT 相关的优化器和用法，将在下一篇文章作详细介绍。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BertOutput(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)\n",
    "        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n",
    "        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
    "\n",
    "    def forward(self, hidden_states, input_tensor):\n",
    "        hidden_states = self.dense(hidden_states)\n",
    "        hidden_states = self.dropout(hidden_states)\n",
    "        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n",
    "        return hidden_states"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*** \n",
    "### 2.2.3 BertPooler\n",
    "这一层只是简单地取出了句子的第一个token，即`[CLS]`对应的向量，然后过一个全连接层和一个激活函数后输出：（这一部分是可选的，因为pooling有很多不同的操作）\n",
    "\n",
    "```\n",
    "class BertPooler(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n",
    "        self.activation = nn.Tanh()\n",
    "\n",
    "    def forward(self, hidden_states):\n",
    "        # We \"pool\" the model by simply taking the hidden state corresponding\n",
    "        # to the first token.\n",
    "        first_token_tensor = hidden_states[:, 0]\n",
    "        pooled_output = self.dense(first_token_tensor)\n",
    "        pooled_output = self.activation(pooled_output)\n",
    "        return pooled_output\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input to bert pooler size: 768\n",
      "torch.Size([1, 768])\n"
     ]
    }
   ],
   "source": [
    "class BertPooler(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n",
    "        self.activation = nn.Tanh()\n",
    "\n",
    "    def forward(self, hidden_states):\n",
    "        # We \"pool\" the model by simply taking the hidden state corresponding\n",
    "        # to the first token.\n",
    "        first_token_tensor = hidden_states[:, 0]\n",
    "        pooled_output = self.dense(first_token_tensor)\n",
    "        pooled_output = self.activation(pooled_output)\n",
    "        return pooled_output\n",
    "from transformers.models.bert.configuration_bert import *\n",
    "import torch\n",
    "config = BertConfig.from_pretrained(\"bert-base-uncased\")\n",
    "bert_pooler = BertPooler(config=config)\n",
    "print(\"input to bert pooler size: {}\".format(config.hidden_size))\n",
    "batch_size = 1\n",
    "seq_len = 2\n",
    "hidden_size = 768\n",
    "x = torch.rand(batch_size, seq_len, hidden_size)\n",
    "y = bert_pooler(x)\n",
    "print(y.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小总结\n",
    "本小节对Bert模型的实现进行分析了学习，希望读者能对Bert实现有一个更为细致的把握。\n",
    "\n",
    "值得注意的是，在 HuggingFace 实现的 Bert 模型中，使用了多种节约显存的技术：\n",
    "\n",
    "- gradient checkpoint，不保留前向传播节点，只在用时计算；apply_chunking_to_forward，按多个小批量和低维度计算 FFN 部\n",
    "- BertModel 包含复杂的封装和较多的组件。以 bert-base 为例，主要组件如下：\n",
    "    - 总计Dropout出现了1+(1+1+1)x12=37次；\n",
    "    - 总计LayerNorm出现了1+(1+1)x12=25次；\n",
    "    - 总计dense全连接层出现了(1+1+1)x12+1=37次，并不是每个dense都配了激活函数……\n",
    "BertModel 有极大的参数量。以 bert-base 为例，其参数量为 109M。\n",
    "\n",
    "## 致谢\n",
    "本文主要由浙江大学李泺秋撰写，本项目同学负责整理和汇总。"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "3bfce0b4c492a35815b5705a19fe374a7eea0baaa08b34d90450caf1fe9ce20b"
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit ('venv': virtualenv)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}